// TODO: As noted for the header-only DBN, parameter files should be used (for creating / training).


// STL
#include <iostream> // std::cerr
#include <string>   // std::string
#include <tuple>    // std::tie()
#include <vector>   // std::vector

// jScience
#include "jScience/linalg.hpp" // Matrix<>

// stats++
#include "statsxx/data.hpp"                 // read_datafile()
#include "statsxx/machine_learning/DBN.hpp" // machine_learning::DBN

// this
#include "_DBN.hpp" // create_DBN(), train_DBN(), test_DBN()
#include "init.hpp" // generate_dataset()


//
// USAGE: dbn filename_param
//
//     filename_param :: parameters filename
//
int main(
         int argc,
         char* argv[]
         )
{
    //=========================================================
    // INITIALIZATION
    //=========================================================

    // Entry point: reads a parameters file, then
    //   1. creates (or reads in) a DBN  -- always,
    //   2. trains it on train_X_file    -- if `train` is set,
    //   3. tests it on test_X_file      -- if `test` is set.
    //
    // Returns 0 on success, 1 if the parameters filename is missing.

    // Guard against a missing argument: argv[1] is only valid when argc >= 2
    // (argc may even be 0, in which case argv[0] is a null pointer).
    if(argc < 2)
    {
        std::cerr << "USAGE: dbn filename_param\n";
        std::cerr << '\n';
        std::cerr << "    filename_param :: parameters filename\n";
        return 1;
    }

    const std::string filename_param{argv[1]};

    
    // ----- PARAMETERS -----
    // NOTE: all of these are overwritten by read_param_file() below; they are
    //       value-initialized only so that no local is ever left indeterminate.
    std::string         dbn_file;
    // =====
    bool                create{};
    // -----
    std::vector<int>    create_nunits;     
    std::vector<int>    create_units_type; 
    // =====
    bool                train{};
    // -----
    std::string         train_X_file;
    // -----
    int                 train_nepoch{};
    int                 train_npts_per_batch{};
    // -----
    int                 train_K_start{};
    int                 train_K_end{};
    double              train_K_rate{};
    // -----
    bool                train_sample_v{};
    bool                train_sample_hdata{};
    // -----
    double              train_lr{};
    std::vector<double> train_lr_dot; 
    std::vector<double> train_lr_adj;    
    double              train_lr_min{};
    double              train_lr_max{};
    // -----
    double              train_p_start{};
    double              train_p_end{};
    double              train_p_rate{};
    std::vector<double> train_p_dot;      
    std::vector<double> train_p_adj;        
    // -----
    double              train_w_penalty{};
    // -----
    //    int free_energy_npts;
    //    int free_energy_nepoch;
    //    int free_energy_window;
    // -----
    int                 train_convg_nno_improvement{};
    double              train_convg_max_inc_max_w{};
    // =====
    bool                test{};
    // -----
    std::string         test_X_file;
    std::string         test_H_file;
    
    // The tuple returned by read_param_file() must list its elements in
    // exactly this order.
    std::tie(
             dbn_file,
             // =====
             create,
             // -----
             create_nunits,
             create_units_type,
             // =====
             train,
             // -----
             train_X_file,
             // -----
             train_nepoch,
             train_npts_per_batch,
             // -----
             train_K_start,
             train_K_end,
             train_K_rate,
             // -----
             train_sample_v,
             train_sample_hdata,
             // -----
             train_lr,
             train_lr_dot,
             train_lr_adj,
             train_lr_min,
             train_lr_max,
             // -----
             train_p_start,
             train_p_end,
             train_p_rate,
             train_p_dot,
             train_p_adj,
             // -----
             train_w_penalty,
             // -----
             train_convg_nno_improvement,
             train_convg_max_inc_max_w,
             // =====
             test,
             // -----
             test_X_file,
             test_H_file
             ) = read_param_file(
                                 filename_param
                                 );

    //=========================================================
    // CREATE (OR READ) DBN
    //=========================================================
    
    // NOTE: this always occurs (create or read-in DBN)
    machine_learning::DBN dbn = create_DBN(
                                           create,
                                           // -----
                                           create_nunits,
                                           create_units_type,
                                           // -----
                                           dbn_file
                                           );
    
    //=========================================================
    // TRAIN
    //=========================================================
        
    if(train)
    {
        Matrix<double> X = read_datafile(
                                         train_X_file
                                         );
/*
        std::cout << "train_X_file        : " << train_X_file << '\n';
        std::cout << '\n';
        std::cout << "train_nepoch        : " << train_nepoch << '\n';
        std::cout << "train_npts_per_batch: " << train_npts_per_batch << '\n';
        std::cout << '\n';
        std::cout << "train_K_start: " << train_K_start << '\n';
        std::cout << "train_K_end: " << train_K_end << '\n';
        std::cout << "train_K_rate: " << train_K_rate << '\n';
        std::cout << '\n';
        std::cout << "train_sample_v: " << train_sample_v << '\n';
        std::cout << "train_sample_hdata: " << train_sample_hdata << '\n';
        std::cout << '\n';
        std::cout << "train_lr: " << train_lr << '\n';
        std::cout << "train_lr_dot.size(): " << train_lr_dot.size() << '\n';
        std::cout << "train_lr_adj.size(): " << train_lr_adj.size() << '\n';
        for(auto i = 0; i < train_lr_dot.size(); ++i)
        {
            std::cout << "train_lr_dot[" << i << "]: " << train_lr_dot[i] << '\n';
            std::cout << "train_lr_adj[" << i << "]: " << train_lr_adj[i] << '\n';
        }
        std::cout << "train_lr_min: " << train_lr_min << '\n';
        std::cout << "train_lr_max: " << train_lr_max << '\n';
        std::cout << '\n';
        std::cout << "train_p_start: " << train_p_start << '\n';
        std::cout << "train_p_end: " << train_p_end << '\n';
        std::cout << "train_p_rate: " << train_p_rate << '\n';
        std::cout << "train_p_dot.size(): " << train_p_dot.size() << '\n';
        std::cout << "train_p_adj.size(): " << train_p_adj.size() << '\n';
        for(auto i = 0; i < train_p_dot.size(); ++i)
        {
            std::cout << "train_p_dot[" << i << "]: " << train_p_dot[i] << '\n';
            std::cout << "train_p_adj[" << i << "]: " << train_p_adj[i] << '\n';
        }
        std::cout << '\n';
        std::cout << "train_w_penalty: " << train_w_penalty << '\n';
        std::cout << '\n';
        std::cout << "train_convg_nno_improvement: " << train_convg_nno_improvement << '\n';
        std::cout << "train_convg_max_inc_max_w: " << train_convg_max_inc_max_w << '\n';
        std::cout << '\n';
        std::cout << "dbn_file: " << dbn_file << '\n';
*/
        // NOTE(review): dbn_file is passed to train_DBN(); presumably the
        //               trained network is written there -- confirm in _DBN.hpp.
        dbn = train_DBN(
                        dbn,
                        // =====
                        X,
                        // =====
                        train_nepoch,
                        train_npts_per_batch,
                        // -----
                        train_K_start,
                        train_K_end,
                        train_K_rate,
                        // -----
                        train_sample_v,
                        train_sample_hdata,
                        // -----
                        train_lr,
                        train_lr_dot,
                        train_lr_adj,
                        train_lr_min,
                        train_lr_max,
                        // -----
                        train_p_start,
                        train_p_end,
                        train_p_rate,
                        train_p_dot,
                        train_p_adj,
                        // -----
                        train_w_penalty,
                        // -----
                        //    int free_energy_npts;
                        //    int free_energy_nepoch;
                        //    int free_energy_window;
                        // -----
                        train_convg_nno_improvement,
                        train_convg_max_inc_max_w,
                        // =====
                        dbn_file
                        );
    }
    
    //=========================================================
    // TEST
    //=========================================================
    
    if(test)
    {
/*
        std::cout << "test_X_file: " << test_X_file << '\n';
        std::cout << "test_H_file: " << test_H_file << '\n';
*/
        Matrix<double> X = read_datafile(
                                         test_X_file
                                         );
        
        // NOTE(review): test_H_file is presumably the output file for the
        //               hidden-layer results -- confirm in _DBN.hpp.
        test_DBN(
                 dbn,
                 // =====
                 X,
                 // =====
                 test_H_file
                 );
    }
    
    
    return 0;
}

